{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 扩展 ONNX 注册表\n",
    "\n",
    "参考：[扩展 ONNX 注册表](https://pytorch.org/tutorials//beginner/onnx/onnx_registry_tutorial.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本教程是对 ONNX 注册表的简介，它使用户能够实现新的 ONNX 算子，甚至用新实现替换现有算子。\n",
    "\n",
    "在模型导出到 ONNX 的过程中，PyTorch 模型被降低到由 [ATen 算子](https://pytorch.org/docs/stable/torch.compiler_ir.html)组成的中间表示。虽然 ATen 算子由 PyTorch 核心团队维护，但 ONNX 导出器团队负责独立地通过 [ONNX Script](https://onnxscript.ai/) 将这些算子实现到 ONNX。用户也可以用自己的实现替换 ONNX 导出器团队实现的行为，以修复错误或针对特定的 ONNX 运行时改进性能。\n",
    "\n",
    "ONNX 注册表管理 PyTorch 算子与 ONNX 算子对应项之间的映射，并提供 API 以扩展注册表。\n",
    "\n",
    "在本教程中，我们将涵盖三种需要使用自定义操作符扩展 ONNX 注册表的场景：\n",
    "\n",
    "- 不受支持的 ATen 算子\n",
    "- 具有现有 ONNX 运行时支持的自定义算子\n",
    "- 没有 ONNX 运行时支持的自定义算子"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 不受支持的 ATen 算子\n",
    "\n",
    "尽管 ONNX 导出器团队尽力支持所有的 ATen 算子，但其中一些可能尚未得到支持。在这一节中，我们将演示如何将不受支持的 ATen 算子添加到 ONNX 注册表中。\n",
    "\n",
    "```{note}\n",
    "实现不受支持的 ATen 算子的步骤与用自定义实现替换现有 ATen 算子的实现相同。由于我们在这个教程中实际上没有一个不受支持的 ATen 算子符可以使用，我们将利用这个机会，用自定义实现来替换 `aten::add.Tensor` 的实现，就像如果该算子不存在于 ONNX 注册表中一样。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "当模型由于不受支持的操作符而无法导出到 ONNX 时，ONNX 导出器将显示类似于以下的错误消息：\n",
    "\n",
    "```\n",
    "RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.add.Tensor']}.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "错误消息指示不受支持的 ATen 算子的完全限定名称是 `aten::add.Tensor`。算子的完全限定名称由命名空间、算子名称和重载组成，格式为 `namespace::operator_name.overload`。\n",
    "\n",
    "要为不受支持的 ATen 算子添加支持或替换现有算子的实现，需要：\n",
    "\n",
    "- ATen 算子的完全限定名称（例如 `aten::add.Tensor`）。此信息总是如上所示的错误消息中出现。\n",
    "- 使用 ONNX Script 的算子实现。ONNX Script 是本教程的先决条件。\n",
    "- 因为 `aten::add.Tensor` 已经被 ONNX 注册表支持，我们将演示如何用自定义实现替换它，但请记住，同样的步骤适用于支持新的不受支持的 ATen 算子。\n",
    "\n",
    "这是可能的，因为 `OnnxRegistry` 允许用户覆盖算子注册。我们将用我们的自定义实现覆盖 `aten::add.Tensor` 的注册，并验证它的存在。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/py311/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aten::add.Tensor is supported by ONNX registry:       True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/py311/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import onnxruntime\n",
    "import onnxscript\n",
    "from onnxscript import opset18  # opset 18 是目前最新的（也是唯一支持的）版本。\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def forward(self, input_x, input_y):\n",
    "        return torch.ops.aten.add(input_x, input_y)  # generates a aten::add.Tensor node\n",
    "\n",
    "input_add_x = torch.randn(3, 4)\n",
    "input_add_y = torch.randn(3, 4)\n",
    "aten_add_model = Model()\n",
    "\n",
    "\n",
    "# 现在，我们创建实现 `aten::add.Tensor` 的 ONNX Script 函数。\n",
    "# 函数名称（例如 `custom_aten_add`）会显示在 ONNX 图中，因此我们建议使用直观的名称。\n",
    "custom_aten = onnxscript.values.Opset(domain=\"custom.aten\", version=1)\n",
    "\n",
    "# NOTE：函数签名必须与不受支持的 ATen 算子的签名匹配。\n",
    "# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml\n",
    "# NOTE：所有属性必须用类型提示进行注解。\n",
    "@onnxscript.script(custom_aten)\n",
    "def custom_aten_add(input_x, input_y, alpha: float = 1.0):\n",
    "    alpha = opset18.CastLike(alpha, input_y)\n",
    "    input_y = opset18.Mul(input_y, alpha)\n",
    "    return opset18.Add(input_x, input_y)\n",
    "\n",
    "\n",
    "# 现在我们已经拥有支持不受支持的 ATen 算子所需的一切。\n",
    "# 让我们将 `custom_aten_add` 函数注册到 ONNX 注册表，并再次将模型导出到 ONNX。\n",
    "onnx_registry = torch.onnx.OnnxRegistry()\n",
    "onnx_registry.register_op(\n",
    "    namespace=\"aten\", op_name=\"add\", overload=\"Tensor\", function=custom_aten_add\n",
    "    )\n",
    "print(f\"aten::add.Tensor is supported by ONNX registry: \\\n",
    "      {onnx_registry.is_registered_op(namespace='aten', op_name='add', overload='Tensor')}\"\n",
    "      )\n",
    "export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)\n",
    "onnx_program = torch.onnx.dynamo_export(\n",
    "    aten_add_model, input_add_x, input_add_y, export_options=export_options\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在让我们检查模型，并验证模型是否使用了 `custom_aten_add` 而不是 `aten::add.Tensor`。图中有一个用于 `custom_aten_add` 的图节点，在其内部有四个函数节点，分别对应每个算子，以及一个用于常量属性的节点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph node domain is the custom domain we registered\n",
    "assert onnx_program.model_proto.graph.node[0].domain == \"custom.aten\"\n",
    "assert len(onnx_program.model_proto.graph.node) == 1\n",
    "# graph node name is the function name\n",
    "assert onnx_program.model_proto.graph.node[0].op_type == \"custom_aten_add\"\n",
    "# function node domain is empty because we use standard ONNX operators\n",
    "assert onnx_program.model_proto.functions[0].node[3].domain == \"\"\n",
    "# function node name is the standard ONNX operator name\n",
    "assert onnx_program.model_proto.functions[0].node[3].op_type == \"Add\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这就是 `custom_aten_add_model` 在 ONNX 图中使用 Netron 的样子：\n",
    "![](images/custom_aten_add.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 `custom_aten_add` 函数内部，我们可以看到在函数中使用的三个 ONNX 节点（`CastLike`、`Add` 和 `Mul`），以及常量属性：![](images/custon_aten_add_2.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这就是我们需要做的全部工作，以将新的 ATen 算子注册到 ONNX 注册表中。作为额外的一步，我们可以使用 ONNX Runtime 运行模型，并将结果与 PyTorch 进行比较。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use ONNX Runtime to run the model, and compare the results with PyTorch\n",
    "onnx_program.save(\"./custom_add_model.onnx\")\n",
    "ort_session = onnxruntime.InferenceSession(\n",
    "    \"./custom_add_model.onnx\", providers=['CPUExecutionProvider']\n",
    "    )\n",
    "\n",
    "def to_numpy(tensor):\n",
    "    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
    "\n",
    "onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_add_x, input_add_y)\n",
    "onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}\n",
    "onnxruntime_outputs = ort_session.run(None, onnxruntime_input)\n",
    "\n",
    "torch_outputs = aten_add_model(input_add_x, input_add_y)\n",
    "torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)\n",
    "\n",
    "assert len(torch_outputs) == len(onnxruntime_outputs)\n",
    "for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):\n",
    "    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 具有现有 ONNX Runtime 支持的自定义算子\n",
    "在这种情况下，用户使用标准的 PyTorch 算子创建模型，但 ONNX 运行时（例如 Microsoft 的 ONNX Runtime）可以为该内核提供自定义实现，有效地替换 ONNX 注册表中的现有实现。另一个用例是，当用户想要使用现有 ONNX 算子的自定义实现来修复错误或提高特定算子的性能时。为了实现这一点，我们只需要将新实现注册到现有的 ATen 完全限定名称。\n",
    "\n",
    "在以下示例中，我们使用了来自 ONNX Runtime 的 `com.microsoft.Gelu`，它与 ONNX 规范中的 Gelu 不同。因此，我们用命名空间 `com.microsoft` 和算子名称 `Gelu` 注册了 Gelu。\n",
    "\n",
    "在我们开始之前，让我们检查 `aten::gelu.default` 是否真的受到 ONNX 注册表的支持。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "aten::gelu.default is supported by ONNX registry:     True\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/tmp/cache/conda/envs/py311/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "onnx_registry = torch.onnx.OnnxRegistry()\n",
    "print(f\"aten::gelu.default is supported by ONNX registry: \\\n",
    "    {onnx_registry.is_registered_op(namespace='aten', op_name='gelu', overload='default')}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在我们的示例中，`aten::gelu.default` 算子受到 ONNX 注册表的支持，所以 `onnx_registry.is_registered_op()` 返回 `True`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'Gelu' is not a known op in 'com.microsoft'\n"
     ]
    }
   ],
   "source": [
    "class CustomGelu(torch.nn.Module):\n",
    "    def forward(self, input_x):\n",
    "        return torch.ops.aten.gelu(input_x)\n",
    "\n",
    "# com.microsoft is an official ONNX Runtime namspace\n",
    "custom_ort = onnxscript.values.Opset(domain=\"com.microsoft\", version=1)\n",
    "\n",
    "# NOTE: The function signature must match the signature of the unsupported ATen operator.\n",
    "# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml\n",
    "# NOTE: All attributes must be annotated with type hints.\n",
    "@onnxscript.script(custom_ort)\n",
    "def custom_aten_gelu(input_x, approximate: str = \"none\"):\n",
    "    # We know com.microsoft::Gelu is supported by ONNX Runtime\n",
    "    # It's only not supported by ONNX\n",
    "    return custom_ort.Gelu(input_x)\n",
    "\n",
    "\n",
    "onnx_registry = torch.onnx.OnnxRegistry()\n",
    "onnx_registry.register_op(\n",
    "    namespace=\"aten\", op_name=\"gelu\", overload=\"default\", function=custom_aten_gelu)\n",
    "export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)\n",
    "\n",
    "aten_gelu_model = CustomGelu()\n",
    "input_gelu_x = torch.randn(3, 3)\n",
    "\n",
    "onnx_program = torch.onnx.dynamo_export(\n",
    "    aten_gelu_model, input_gelu_x, export_options=export_options\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "让我们检查模型，并验证模型是否使用了 `custom_aten_gelu()` 而不是 `aten::gelu`。注意图中有一个用于 `custom_aten_gelu` 的图节点，在 `custom_aten_gelu` 内部，有一个命名空间为 `com.microsoft` 的 `Gelu` 函数节点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# graph node domain is the custom domain we registered\n",
    "assert onnx_program.model_proto.graph.node[0].domain == \"com.microsoft\"\n",
    "# graph node name is the function name\n",
    "assert onnx_program.model_proto.graph.node[0].op_type == \"custom_aten_gelu\"\n",
    "# function node domain is the custom domain we registered\n",
    "assert onnx_program.model_proto.functions[0].node[0].domain == \"com.microsoft\"\n",
    "# function node name is the node name used in the function\n",
    "assert onnx_program.model_proto.functions[0].node[0].op_type == \"Gelu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面的图表显示了使用 Netron 的 `custom_aten_gelu_model` ONNX 图：![](images/custom_aten_gelu.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 `custom_aten_gelu` 函数内部，我们可以看到在函数中使用的来自模块 `com.microsoft` 的 `Gelu` 节点：![](images/custom_aten_gelu2.jpg)。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这就是我们需要做的全部工作。作为额外的步骤，我们可以使用 ONNX Runtime 运行模型，并将结果与 PyTorch 进行比较。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_program.save(\"./custom_gelu_model.onnx\")\n",
    "ort_session = onnxruntime.InferenceSession(\n",
    "    \"./custom_gelu_model.onnx\", providers=['CPUExecutionProvider']\n",
    "    )\n",
    "\n",
    "def to_numpy(tensor):\n",
    "    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
    "\n",
    "onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_gelu_x)\n",
    "onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}\n",
    "onnxruntime_outputs = ort_session.run(None, onnxruntime_input)\n",
    "\n",
    "torch_outputs = aten_gelu_model(input_gelu_x)\n",
    "torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)\n",
    "\n",
    "assert len(torch_outputs) == len(onnxruntime_outputs)\n",
    "for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):\n",
    "    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 没有 ONNX Runtime 支持的自定义算子\n",
    "\n",
    "在这种情况下，算子不受任何 ONNX 运行时的支持，但我们希望将其作为自定义算子使用在 ONNX 图中。因此，我们需要在三个地方实现该算子：\n",
    "\n",
    "1. PyTorch FX 图\n",
    "2. ONNX 注册表\n",
    "3. ONNX 运行时\n",
    "\n",
    "在以下示例中，我们希望使用一个自定义算子，它接受一个张量输入，并返回一个输出。该算子将输入加到自身，并返回四舍五入的结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在 PyTorch FX 图中注册自定义算子（Beta）\n",
    "\n",
    "首先，我们需要在 PyTorch FX 图中实现该算子。这可以通过使用 `torch._custom_op` 来完成。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: This is a beta feature in PyTorch, and is subject to change.\n",
    "from torch._custom_op import impl as custom_op\n",
    "\n",
    "@custom_op.custom_op(\"mylibrary::addandround_op\")\n",
    "def addandround_op(tensor_x: torch.Tensor) -> torch.Tensor:\n",
    "    ...\n",
    "\n",
    "@addandround_op.impl_abstract()\n",
    "def addandround_op_impl_abstract(tensor_x):\n",
    "    return torch.empty_like(tensor_x)\n",
    "\n",
    "@addandround_op.impl(\"cpu\")\n",
    "def addandround_op_impl(tensor_x):\n",
    "    return torch.round(tensor_x + tensor_x)  # add x to itself, and round the result\n",
    "\n",
    "torch._dynamo.allow_in_graph(addandround_op)\n",
    "\n",
    "class CustomFoo(torch.nn.Module):\n",
    "    def forward(self, tensor_x):\n",
    "        return addandround_op(tensor_x)\n",
    "\n",
    "input_addandround_x = torch.randn(3)\n",
    "custom_addandround_model = CustomFoo()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在 ONNX 注册表中注册自定义算子\n",
    "对于步骤2和3，我们需要在 ONNX 注册表中实现该算子。在这个例子中，我们将使用命名空间 `test.customop` 和算子名称 `CustomOpOne` 以及 `CustomOpTwo` 在 ONNX 注册表中实现该算子。这两个算子已在 [`cpu_ops.cc`](https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc) 中注册并构建。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'CustomOpOne' is not a known op in 'test.customop'\n",
      "'CustomOpTwo' is not a known op in 'test.customop'\n",
      "/media/pc/data/tmp/cache/conda/envs/py311/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:137: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "custom_opset = onnxscript.values.Opset(domain=\"test.customop\", version=1)\n",
    "\n",
    "# NOTE: The function signature must match the signature of the unsupported ATen operator.\n",
    "# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml\n",
    "# NOTE: All attributes must be annotated with type hints.\n",
    "@onnxscript.script(custom_opset)\n",
    "def custom_addandround(input_x):\n",
    "    # The same as opset18.Add(x, x)\n",
    "    add_x = custom_opset.CustomOpOne(input_x, input_x)\n",
    "    # The same as opset18.Round(x, x)\n",
    "    round_x = custom_opset.CustomOpTwo(add_x)\n",
    "    # Cast to FLOAT to match the ONNX type\n",
    "    return opset18.Cast(round_x, to=1)\n",
    "\n",
    "\n",
    "onnx_registry = torch.onnx.OnnxRegistry()\n",
    "onnx_registry.register_op(\n",
    "    namespace=\"mylibrary\", op_name=\"addandround_op\", overload=\"default\", function=custom_addandround\n",
    "    )\n",
    "\n",
    "export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)\n",
    "onnx_program = torch.onnx.dynamo_export(\n",
    "    custom_addandround_model, input_addandround_x, export_options=export_options\n",
    "    )\n",
    "onnx_program.save(\"./custom_addandround_model.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`onnx_program` 通过 `onnx_program.model_proto` 将导出的模型暴露为 protobuf。图中有一个用于 `custom_addandround` 的图节点，在 `custom_addandround` 内部，有两个函数节点，每个算子一个。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert onnx_program.model_proto.graph.node[0].domain == \"test.customop\"\n",
    "assert onnx_program.model_proto.graph.node[0].op_type == \"custom_addandround\"\n",
    "assert onnx_program.model_proto.functions[0].node[0].domain == \"test.customop\"\n",
    "assert onnx_program.model_proto.functions[0].node[0].op_type == \"CustomOpOne\"\n",
    "assert onnx_program.model_proto.functions[0].node[1].domain == \"test.customop\"\n",
    "assert onnx_program.model_proto.functions[0].node[1].op_type == \"CustomOpTwo\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在 ONNX Runtime 中注册自定义算子\n",
    "\n",
    "要将您的自定义算子库链接到 ONNX Runtime，您需要将 C++ 代码编译成共享库，并将其链接到 ONNX Runtime。请按照以下说明操作：\n",
    "\n",
    "1. 按照 [ONNX Runtime 的指令](https://pytorch.org/tutorials//beginner/onnx/%60https://github.com/microsoft/onnxruntime/blob/gh-pages/docs/reference/operators/add-custom-op.md)用 C++ 实现您的自定义算子。\n",
    "2. 从 [ONNX Runtime 发布页面](https://github.com/microsoft/onnxruntime/releases)下载 ONNX Runtime 源码分发版。\n",
    "3. 编译并将您的自定义算子库链接到 ONNX Runtime，例如：\n",
    "    ```bash\n",
    "    $ gcc -shared -o libcustom_op_library.so custom_op_library.cc -L /path/to/downloaded/ort/lib/ -lonnxruntime -fPIC\n",
    "    ```\n",
    "4. 使用 ONNX Runtime Python API 运行模型，并将结果与 PyTorch 进行比较。\n",
    "    ```python\n",
    "    ort_session_options = onnxruntime.SessionOptions()\n",
    "\n",
    "    # NOTE: Link the custom op library to ONNX Runtime and replace the path\n",
    "    # with the path to your custom op library\n",
    "    ort_session_options.register_custom_ops_library(\n",
    "        \"/path/to/libcustom_op_library.so\"\n",
    "    )\n",
    "    ort_session = onnxruntime.InferenceSession(\n",
    "        \"./custom_addandround_model.onnx\", providers=['CPUExecutionProvider'], sess_options=ort_session_options)\n",
    "\n",
    "    def to_numpy(tensor):\n",
    "        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
    "\n",
    "    onnx_input = onnx_program.adapt_torch_inputs_to_onnx(input_addandround_x)\n",
    "    onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}\n",
    "    onnxruntime_outputs = ort_session.run(None, onnxruntime_input)\n",
    "\n",
    "    torch_outputs = custom_addandround_model(input_addandround_x)\n",
    "    torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)\n",
    "\n",
    "    assert len(torch_outputs) == len(onnxruntime_outputs)\n",
    "    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):\n",
    "        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))\n",
    "    ```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
